{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# python notebook for Make Your Own Neural Network\n",
    "# code for a 3-layer neural network, and code for learning the MNIST dataset\n",
    "# this version trains using the MNIST dataset, then tests on our own images\n",
    "# (c) Tariq Rashid, 2016\n",
    "# license is GPLv2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import numpy\n",
    "# scipy.special for the sigmoid function expit()\n",
    "import scipy.special\n",
    "# library for plotting arrays\n",
    "import matplotlib.pyplot\n",
    "# ensure the plots are inside this notebook, not an external window\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# helper to load data from PNG image files\n",
    "import scipy.misc\n",
    "# glob helps select multiple files using patterns\n",
    "import glob"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# neural network class definition\n",
    "class neuralNetwork:\n",
    "    \"\"\"Three-layer (input, hidden, output) feed-forward network with sigmoid activations.\n",
    "\n",
    "    Attributes:\n",
    "        inodes, hnodes, onodes: number of nodes in the input, hidden and output layers.\n",
    "        wih: weight matrix of shape (hnodes, inodes), input layer to hidden layer.\n",
    "        who: weight matrix of shape (onodes, hnodes), hidden layer to output layer.\n",
    "        lr: learning rate applied to every weight update.\n",
    "        activation_function: element-wise sigmoid (scipy.special.expit).\n",
    "    \"\"\"\n",
    "\n",
    "    # initialise the neural network\n",
    "    def __init__(self, inputnodes, hiddennodes, outputnodes, learningrate):\n",
    "        \"\"\"Store the layer sizes and learning rate, and draw random initial weights.\"\"\"\n",
    "        # set number of nodes in each input, hidden, output layer\n",
    "        self.inodes = inputnodes\n",
    "        self.hnodes = hiddennodes\n",
    "        self.onodes = outputnodes\n",
    "\n",
    "        # link weight matrices, wih and who\n",
    "        # weights inside the arrays are w_i_j, where link is from node i to node j in the next layer\n",
    "        # w11 w21\n",
    "        # w12 w22 etc\n",
    "        # the standard deviation is 1/sqrt(number of incoming links) of the receiving node,\n",
    "        # i.e. inodes for the hidden layer and hnodes for the output layer\n",
    "        self.wih = numpy.random.normal(0.0, pow(self.inodes, -0.5), (self.hnodes, self.inodes))\n",
    "        self.who = numpy.random.normal(0.0, pow(self.hnodes, -0.5), (self.onodes, self.hnodes))\n",
    "\n",
    "        # learning rate\n",
    "        self.lr = learningrate\n",
    "\n",
    "        # activation function is the sigmoid function\n",
    "        self.activation_function = scipy.special.expit\n",
    "\n",
    "    # forward pass shared by train() and query()\n",
    "    def _forward(self, inputs):\n",
    "        \"\"\"Propagate a 2d column array of inputs; return (hidden_outputs, final_outputs).\"\"\"\n",
    "        # calculate signals into hidden layer\n",
    "        hidden_inputs = numpy.dot(self.wih, inputs)\n",
    "        # calculate the signals emerging from hidden layer\n",
    "        hidden_outputs = self.activation_function(hidden_inputs)\n",
    "\n",
    "        # calculate signals into final output layer\n",
    "        final_inputs = numpy.dot(self.who, hidden_outputs)\n",
    "        # calculate the signals emerging from final output layer\n",
    "        final_outputs = self.activation_function(final_inputs)\n",
    "\n",
    "        return hidden_outputs, final_outputs\n",
    "\n",
    "    # train the neural network\n",
    "    def train(self, inputs_list, targets_list):\n",
    "        \"\"\"Apply one gradient-descent weight update for a training record.\n",
    "\n",
    "        Args:\n",
    "            inputs_list: input values, one per input node.\n",
    "            targets_list: desired output values, one per output node.\n",
    "\n",
    "        Returns:\n",
    "            None. self.who and self.wih are updated in place.\n",
    "        \"\"\"\n",
    "        # convert inputs and targets lists to 2d column arrays\n",
    "        inputs = numpy.array(inputs_list, ndmin=2).T\n",
    "        targets = numpy.array(targets_list, ndmin=2).T\n",
    "\n",
    "        hidden_outputs, final_outputs = self._forward(inputs)\n",
    "\n",
    "        # output layer error is the (target - actual)\n",
    "        output_errors = targets - final_outputs\n",
    "        # hidden layer error is the output_errors, split by weights, recombined at hidden nodes\n",
    "        # (must be computed before who is updated below)\n",
    "        hidden_errors = numpy.dot(self.who.T, output_errors)\n",
    "\n",
    "        # update the weights for the links between the hidden and output layers\n",
    "        # out * (1 - out) is the derivative of the sigmoid at each node\n",
    "        self.who += self.lr * numpy.dot((output_errors * final_outputs * (1.0 - final_outputs)), numpy.transpose(hidden_outputs))\n",
    "\n",
    "        # update the weights for the links between the input and hidden layers\n",
    "        self.wih += self.lr * numpy.dot((hidden_errors * hidden_outputs * (1.0 - hidden_outputs)), numpy.transpose(inputs))\n",
    "\n",
    "    # query the neural network\n",
    "    def query(self, inputs_list):\n",
    "        \"\"\"Return the network's output activations for inputs_list as a 2d column array.\"\"\"\n",
    "        # convert inputs list to 2d column array\n",
    "        inputs = numpy.array(inputs_list, ndmin=2).T\n",
    "\n",
    "        hidden_outputs, final_outputs = self._forward(inputs)\n",
    "\n",
    "        return final_outputs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# number of input, hidden and output nodes\n",
    "# 784 inputs = one per pixel of a 28x28 image, 10 outputs = one per digit 0-9\n",
    "input_nodes = 784\n",
    "hidden_nodes = 200\n",
    "output_nodes = 10\n",
    "\n",
    "# learning rate\n",
    "learning_rate = 0.1\n",
    "\n",
    "# seed numpy's global random generator so the initial weights, and therefore\n",
    "# the trained network, are the same on every Restart & Run All\n",
    "numpy.random.seed(42)\n",
    "\n",
    "# create instance of neural network\n",
    "n = neuralNetwork(input_nodes, hidden_nodes, output_nodes, learning_rate)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# load the mnist training data CSV file into a list of text lines (one record per line)\n",
    "# the with-block closes the file even if reading fails\n",
    "with open(\"mnist_dataset/mnist_train.csv\", 'r') as training_data_file:\n",
    "    training_data_list = training_data_file.readlines()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# train the neural network\n",
    "\n",
    "# epochs is the number of times the training data set is used for training\n",
    "epochs = 10\n",
    "\n",
    "for e in range(epochs):\n",
    "    # go through all records in the training data set\n",
    "    for record in training_data_list:\n",
    "        # skip blank lines (e.g. a trailing empty line at the end of the CSV)\n",
    "        if not record.strip():\n",
    "            continue\n",
    "        # split the record by the ',' commas\n",
    "        all_values = record.split(',')\n",
    "        # scale and shift the inputs from 0..255 into 0.01..1.0\n",
    "        # (numpy.asarray with dtype=float replaces numpy.asfarray, removed in NumPy 2.0)\n",
    "        inputs = (numpy.asarray(all_values[1:], dtype=float) / 255.0 * 0.99) + 0.01\n",
    "        # create the target output values (all 0.01, except the desired label which is 0.99)\n",
    "        targets = numpy.zeros(output_nodes) + 0.01\n",
    "        # all_values[0] is the target label for this record\n",
    "        targets[int(all_values[0])] = 0.99\n",
    "        n.train(inputs, targets)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "loading ...  my_own_images/2828_my_own_2.png\n",
      "0.01\n",
      "1.0\n",
      "loading ...  my_own_images/2828_my_own_3.png\n",
      "0.01\n",
      "1.0\n",
      "loading ...  my_own_images/2828_my_own_4.png\n",
      "0.01\n",
      "0.930118\n",
      "loading ...  my_own_images/2828_my_own_5.png\n",
      "0.01\n",
      "0.868\n",
      "loading ...  my_own_images/2828_my_own_6.png\n",
      "0.01\n",
      "1.0\n",
      "loading ...  my_own_images/2828_my_own_noisy_6.png\n",
      "0.145882\n",
      "0.774824\n"
     ]
    }
   ],
   "source": [
    "# our own image test data set: one record per image, label first, then 784 pixel values\n",
    "our_own_dataset = []\n",
    "\n",
    "# luma weights (ITU-R 601) used to collapse R, G, B into a single grey channel\n",
    "grey_weights = [0.299, 0.587, 0.114]\n",
    "\n",
    "# load the png image data as test data set\n",
    "# sorted() fixes the record order, so indexing our_own_dataset by position is reproducible\n",
    "for image_file_name in sorted(glob.glob('my_own_images/2828_my_own_*.png')):\n",
    "\n",
    "    # use the filename to set the correct label (the single character before '.png')\n",
    "    label = int(image_file_name[-5:-4])\n",
    "\n",
    "    # load image data from png files into an array\n",
    "    print (\"loading ... \", image_file_name)\n",
    "    # matplotlib returns PNG pixels as floats in 0..1, shaped (M, N), (M, N, 3) or (M, N, 4)\n",
    "    # (scipy.misc.imread, used here previously, was removed in SciPy 1.2)\n",
    "    img_array = matplotlib.pyplot.imread(image_file_name)\n",
    "    if img_array.ndim == 3:\n",
    "        # colour image: ignore any alpha channel and combine R, G, B into grey\n",
    "        img_array = numpy.dot(img_array[..., :3], grey_weights)\n",
    "    # rescale to the 0..255 range the rest of this cell expects\n",
    "    img_array = img_array * 255.0\n",
    "\n",
    "    # reshape from 28x28 to list of 784 values, invert values\n",
    "    img_data = 255.0 - img_array.reshape(784)\n",
    "\n",
    "    # then scale data to range from 0.01 to 1.0\n",
    "    img_data = (img_data / 255.0 * 0.99) + 0.01\n",
    "    print(numpy.min(img_data))\n",
    "    print(numpy.max(img_data))\n",
    "\n",
    "    # append label and image data to test data set\n",
    "    record = numpy.append(label, img_data)\n",
    "    our_own_dataset.append(record)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[ 0.03735814]\n",
      " [ 0.02521326]\n",
      " [ 0.00242829]\n",
      " [ 0.00282901]\n",
      " [ 0.10258703]\n",
      " [ 0.1166545 ]\n",
      " [ 0.15649976]\n",
      " [ 0.08053335]\n",
      " [ 0.00169163]\n",
      " [ 0.00124463]]\n",
      "network says  6\n",
      "match!\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP4AAAD8CAYAAABXXhlaAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAADuJJREFUeJzt3X+MVfWZx/HPwwwODqN10sgPRazEtJs0GsOmJhtIvJXY\nGtME0z9c7WoQf8ZgS7aaoGjCRU1s/cMfG9M/pLRgrWlpk4rVZNcScxXTdCG7pYvbQRsNFq0Mrihh\nBIWFZ/+YC3tnnPs9lzn33Hvgeb+SiWfOc86c7z3yueee+z3ne8zdBSCWKd1uAIDOI/hAQAQfCIjg\nAwERfCAggg8ElCv4Znalme0wszfNbEW7GgWgWDbZfnwzmyLpTUmLJP1N0lZJ17r7jnHLcaEA0CXu\nbhPNz3PEv1TSX9z9HXc/LOkXkhY32fjxn1WrVo35vWw/tO/UbV+Z21ZE+1LyBP9cSbsafn+3Pg9A\nyfHlHhBQb45135M0t+H3OfV5n1OtVo9Pn3XWWTk2WbxKpdLtJiTRvskrc9uk/O2r1Wqq1WotLZvn\ny70eSW9o9Mu99yVtkXSduw+NW84nuw0Ak2dm8iZf7k36iO/uR8zsTkkvafSUYe340AMop0kf8Vve\nAEd8oCsKOeIjhqw37b179zatDQwMJNft6+ubVJuQH9/qAwERfCAggg8ERPCBgAg+EBDBBwIi+EBA\n9OOf4rL64d97b8LbK4576qmnkvWlS5c2rWX146N7OOIDARF8ICCCDwRE8IGACD4QEMEHAuJ+/OCO\nHj2arPf2pnt8n3nmmaa16667Lrmu2YS3iqNNUvfjc8QHAiL4QEAEHwiI4AMBEXwgIIIPBETwgYDo\nxz/FZe37lStXJus9PT3J+rJly5rWZs+enVwXxaIfH8AYBB8IiOADARF8ICCCDwRE8IGACD4QUK5+\nfDPbKWmfpKOSDrv7pRMsQz9+gbL27bp165L12267LVn/+OOPk/X+/v6mNe63765UP37ecfWPSqq4\n+0c5/w6ADsr7Ud/a8DcAdFje0Lqk35nZVjO7tR0NAlC8vB/1F7j7+2Z2tkbfAIbc/bXxC1Wr1ePT\nlUpFlUol52YBjFer1VSr1Vpatm036ZjZKkn73f3RcfP5cq9AfLmHZgq5ScfM+s1soD49XdI3JL0+\n2b8HoHPyfNSfKek3Zub1v/Nzd3+pPc0CUCTuxy+5rH338MMPJ+v3339/sn7gwIFkva+vL1nn43x5\ncT8+gDEIPhAQwQcCIvhAQAQfCIjgAwERfCCgvNfqI6esfvpPPvkkWf/www+T9VdeeSVZnzKF9/6I\n+L8OBETwgYAIPhAQwQcCIvhAQAQfCIjgAwHRj19yS5cuTdaznkG/cOHCdjYHpwiO+EBABB8IiOAD\nARF8ICCCDwRE8IGACD4QEP34JXf48OFk/eyzz07WGfceE+GIDwRE8IGACD4QEMEHAiL4QEAEHwiI\n4AMBZfbjm9laSd+SNOzuF9fnDUr6paTzJe2UdI277yuwnaWVNS7+oUOHkvWDBw8m61OnTk3Wd+3a\nlaxffvnlyfrmzZuT9bvvvjtZX7RoUdPaZZddllw367WhOK0c8X8q6Zvj5t0jaZO7f0XSy5LubXfD\nABQnM/ju/pqkj8bNXixpfX16vaSr29wuAAWa7Dn+DHcfliR33y1pRvuaBKBo7bpWP3miW61Wj09X\nKhVVKpU2bRbAMbVaTbVaraVlJxv8YTOb6e7DZjZL0p7Uwo3BB1CM8QfV1atXN1221Y/6Vv855nlJ\nN9anl0jaeCINBNBdmcE3s2cl/V7Sl83sr2a2VNIPJF1hZm9IWlT/HcBJwrL6oXNvwMyL3kYeWW07\ncuRIsr5nT/IsJ/P59r296bOtefPm5apfddVVyfrg4GCyPn/+/GT9vvvua1pL9fFL0mOPPZas79+/\nP1mfPn16sh79OgEzk7tPOCADV+4BAR
F8ICCCDwRE8IGACD4QEMEHAiL4QECMq58hq58/a9z68847\nL1l/9dVXk/XTTz89WR8aGkrWi+7LTo37f8sttyTXffDBB5P1rGsoMHkc8YGACD4QEMEHAiL4QEAE\nHwiI4AMBEXwgoPD342cpuu333psembynpydZf+ihh5L1rOsM8krtnwceeCC57ltvvZWsr127dlJt\nOob78bkfH0ADgg8ERPCBgAg+EBDBBwIi+EBABB8IiPvxM2T1g2f182/ZsiVZ37ZtW7K+YcOGZL2b\n/fSSdOjQoaa1rPvxFy5cmKy/+OKLyfoVV1yRrEfvx0/hiA8ERPCBgAg+EBDBBwIi+EBABB8IiOAD\nAWX245vZWknfkjTs7hfX562SdKukYw+HX+nu/1pYK09iWX3RmzdvTtbPOOOMdjan7Xp7m/8TOuec\nc5LrDg4OJutZYxH09/cn62iulSP+TyV9c4L5j7r7/PoPoQdOIpnBd/fXJH00QanYS8YAFCbPOf6d\nZrbNzH5sZl9oW4sAFG6y1+r/SNID7u5m9pCkRyXd3GzharV6fLpSqahSqUxyswCaqdVqqtVqLS07\nqeC7+wcNv66R9NvU8o3BB1CM8QfV1atXN1221Y/6poZzejOb1VD7tqTXT6iFALqqle68ZyVVJH3R\nzP4qaZWkr5vZJZKOStop6fYC2wigzRhXP6es17ZmzZpkPet++ptvbvrViSRpypRir8HKen2p+sjI\nSHLd5557Lll/8sknk/VNmzYl62eeeWayfqpjXH0AYxB8ICCCDwRE8IGACD4QEMEHAiL4QECMq58h\nz7jykvTII48k63fddVeyXvS4+Xl99tlnTWtZ/fg33HBDsr5///5k/cILL0zW33777WR9YGAgWT+V\nccQHAiL4QEAEHwiI4AMBEXwgIIIPBETwgYDox88p637466+/Pll/4YUXkvXbb0+PcVJ0P3/W3582\nbVrT2uzZs3Nte8mSJcl61vhyWePyR8YRHwiI4AMBEXwgIIIPBETwgYAIPhAQwQcCYlz9DHnGlZey\n79dfvnx5sr5ixYpkfd68ecl6mWXtu5tuuilZf/zxx5P1rHH1yz7WQV6Mqw9gDIIPBETwgYAIPhAQ\nwQcCIvhAQAQfCCizH9/M5kh6WtJMSUclrXH3fzGzQUm/lHS+pJ2SrnH3fROsf1L342fJem2ffvpp\nsr5gwYJkPWvs+YsuuihZz+qrzhpbvrc3PWTD1KlTm9YOHz6cXDfrte3YsSNZHxoaStYvuOCCZP20\n005L1k92efvx/1fS9939q5L+QdIyM/s7SfdI2uTuX5H0sqR729VgAMXKDL6773b3bfXpEUlDkuZI\nWixpfX2x9ZKuLqqRANrrhM7xzexLki6R9AdJM919WBp9c5A0o92NA1CMlsfcM7MBSb+WtNzdR8xs\n/Mlt05PdarV6fLpSqahSqZxYKwFkqtVqmeMQHtNS8M2sV6Oh/5m7b6zPHjazme4+bGazJO1ptn5j\n8AEUY/xBdfXq1U2XbfWj/k8k/dndn2iY97ykG+vTSyRtHL8SgHLKPOKb2QJJ/yRpu5n9UaMf6VdK\n+qGkDWZ2k6R3JF1TZEMBtA/34+eU9dqOHDmSrK9bty5Z3759e7K+b9/nLp0YY3h4OFnv7+9P1vfu\n3ZusT58+vWkt6374rH13xx13JOtz585N1rPG9e/r60vWT3bcjw9gDIIPBETwgYAIPhAQwQcCIvhA\nQAQfCIh+/JyyXtuBAweS9axx9/PeL5/XwYMHk/Vp06Y1rZV93Pqyty8v+vEBjEHwgYAIPhAQwQcC\nIvhAQAQfCIjgAwHRj59T3vvxe3p6cm2/6L7oPP/vTvV+8rKjHx/AGAQfCIjgAwERfCAggg8ERPCB\ngAg+EFCxN3MHkNVXndVPX/a+7rK3D5PDER8IiOADARF8ICCCDwRE8IGACD4QUGbwzWyOmb1sZv9t\nZt
vN7Lv1+avM7F0z+8/6z5XFNxdAO2Tej29msyTNcvdtZjYg6T8kLZb0j5L2u/ujGeuf0vfjA2WV\nuh8/8wIed98taXd9esTMhiSde+xvt62VADrmhM7xzexLki6R9O/1WXea2TYz+7GZfaHNbQNQkJaD\nX/+Y/2tJy919RNKPJM1z90s0+okg+ZEfQHm0dK2+mfVqNPQ/c/eNkuTuHzQsskbSb5utX61Wj09X\nKhVVKpVJNBVASq1WU61Wa2nZlgbbNLOnJf2Pu3+/Yd6s+vm/zOyfJX3N3b8zwbp8uQd0QerLvVa+\n1V8g6VVJ2yV5/WelpO9o9Hz/qKSdkm539+EJ1if4QBfkCn4bNk7wgS5geG0AYxB8ICCCDwRE8IGA\nCD4QEMEHAiL4QEAEHwiI4AMBEXwgIIIPBETwgYA6HvxW7xfuFtqXT5nbV+a2SZ1tH8Efh/blU+b2\nlblt0ikefADdR/CBgDoyEEehGwDQVNdG4AFQPnzUBwIi+EBAHQu+mV1pZjvM7E0zW9Gp7bbKzHaa\n2Z/M7I9mtqUE7VlrZsNm9l8N8wbN7CUze8PM/q2bTy9q0r7SPEh1goe9fq8+vxT7sNsPo+3IOb6Z\nTZH0pqRFkv4maauka919R+Ebb5GZvS3p7939o263RZLMbKGkEUlPu/vF9Xk/lPShuz9Sf/McdPd7\nStS+VWrhQaqdkHjY61KVYB/mfRhtXp064l8q6S/u/o67H5b0C42+yDIxlejUx91fkzT+TWixpPX1\n6fWSru5ooxo0aZ9Ukgepuvtud99Wnx6RNCRpjkqyD5u0r2MPo+3UP/RzJe1q+P1d/f+LLAuX9Dsz\n22pmt3a7MU3MOPbQkvpTjGZ0uT0TKd2DVBse9voHSTPLtg+78TDa0hzhSmCBu8+XdJWkZfWPsmVX\ntr7Y0j1IdYKHvY7fZ13dh916GG2ngv+epLkNv8+pzysNd3+//t8PJP1Go6cnZTNsZjOl4+eIe7rc\nnjHc/YOGxyatkfS1brZnooe9qkT7sNnDaDuxDzsV/K2SLjSz883sNEnXSnq+Q9vOZGb99Xdemdl0\nSd+Q9Hp3WyVp9Fyv8XzveUk31qeXSNo4foUOG9O+epCO+ba6vw9/IunP7v5Ew7wy7cPPta9T+7Bj\nV+7VuyWe0OibzVp3/0FHNtwCM7tAo0d51+ijw3/e7faZ2bOSKpK+KGlY0ipJz0n6laTzJL0j6Rp3\n/7hE7fu6WniQaofa1+xhr1skbVCX92Heh9Hm3j6X7ALx8OUeEBDBBwIi+EBABB8IiOADARF8ICCC\nDwRE8IGA/g9siGGfSZcckAAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<matplotlib.figure.Figure at 0x1097b67f0>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# check the trained network against one of our own images\n",
    "\n",
    "# index of the record to test\n",
    "item = 4\n",
    "\n",
    "# each record holds the correct label first, followed by the 784 pixel values\n",
    "correct_label = our_own_dataset[item][0]\n",
    "inputs = our_own_dataset[item][1:]\n",
    "\n",
    "# show the image being tested\n",
    "matplotlib.pyplot.imshow(inputs.reshape(28,28), cmap='Greys', interpolation='None')\n",
    "\n",
    "# ask the network for its output activations\n",
    "outputs = n.query(inputs)\n",
    "print (outputs)\n",
    "\n",
    "# the position of the largest activation is the predicted digit\n",
    "label = numpy.argmax(outputs)\n",
    "print(\"network says \", label)\n",
    "\n",
    "# report whether the prediction agrees with the label taken from the filename\n",
    "print (\"match!\" if label == correct_label else \"no match!\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
